# Modular causal GAN training (conditional WGAN-GP critics) on the MNIST causal graph
import copy
import json
import os
import random
from tqdm import tqdm


import numpy as np
import torch
import pickle

from Causal_Partial_Mnist.CausalGraph_Mnist import set_mnist_nonId_newgraph
# import torch.nn.functional as F
# from matplotlib import pyplot as plt
# from Alignment_Project.alignmentEvaluation import alignmentEvaluation
# from Asia_Modular_Training.asiaEvaluation import asiaEvaluation
# from Benchmarks_Training.Backdoor_Traininig.backdoorEvaluation import backdoorEvaluation
# from Benchmarks_Training.Frontdoor_Training.frontdoorEvaluation import frontdoorEvaluation
# from Benchmarks_Training.TransportGraph_Traininig.TransportEvaluation import transportEvaluation
# from CausalMNISTAddition.DigitImageGeneration.mnist_image_generation import plot_trained_digits
# from CausalSachs.GAN_Evaluation.sachsEvaluation import sachsEvaluation


from Image_Mediator_Training.Image_Mediator_Evaluation import imageMediatorEvaluation
from ModularUtils.ControllerModel import get_discriminators
from ModularUtils.Experiment_Class import Experiment
from ModularUtils.FunctionsConstant import get_Imagedataset, asKey, get_dataset
from ModularUtils.FunctionsTraining import get_training_variables


def load_image_dataset(Exp, cur_hnodes):
    """Collect the image tensors needed by the mechanisms in ``cur_hnodes``.

    Args:
        Exp: experiment object (reads ``Data_intervs``, ``image_labels`` and ``DEVICE``).
        cur_hnodes: mapping from hnode identifier to the list of mechanism labels
            trained together (only the values are used).

    Returns:
        dict: ``asKey(intervention) -> tensor``. For every dataset, one image dataset is
        loaded per compared variable that is an image label, and the results are
        concatenated along dim 1. Datasets without any image variable get no entry.
        Entries are accumulated over all hnodes (a later hnode overwrites an equal key).
    """
    image_data_dict = {}
    for cur_mechs in cur_hnodes.values():
        for dno, intv in enumerate(Exp.Data_intervs):
            all_compare_Var, _, _, _ = get_training_variables(Exp, cur_mechs, dno, {})

            image_dataset = []
            for mech in all_compare_Var:
                # Compare whole label names: ``set(mech)`` would split the name into
                # single characters and never match a multi-character image label.
                if mech in Exp.image_labels:
                    # NOTE(review): dataset index 0 and "ImgYdigit1" are hard-coded, so every
                    # intervention reuses the same image file — confirm this is intended.
                    image_dataset.append(get_Imagedataset(Exp, 0, "ImgYdigit1"))
            if image_dataset:
                image_data_dict[asKey(intv)] = torch.cat(image_dataset, 1).to(Exp.DEVICE)

    return image_data_dict


def load_label_dataset(Exp, image_data_dict, label_generators, cur_hnodes, bayes_graph=None):  #get all datasets despite any hnodes
    """Load the label (non-image) datasets for every entry of ``Exp.Data_intervs``.

    For each dataset index ``dno`` the result holds, under key ``asKey(intv)``:

    * ``"obs"``: the ``get_dataset(Exp, label, dno)`` tensors for every label in
      ``real_labels_vars`` that is not a representation label, concatenated along
      dim 1 and moved to ``Exp.DEVICE``.
    * ``"rep"``: present only when ``Exp.rep_labels`` is non-empty. Holds the latent codes
      returned by ``label_generators[rep](..., isLatent=True)`` for the observational
      images ``image_data_dict[asKey({})]``, conditioned on the one-hot parent labels
      of dataset ``dno``.

    Args:
        Exp: experiment object.
        image_data_dict: mapping ``asKey(intervention) -> image tensor``. Only the
            observational entry ``asKey({})`` is read, and only when rep labels exist.
        label_generators: mapping from label name to its generator.
        cur_hnodes: unused; kept for interface compatibility.
        bayes_graph: if not ``None``, ``prepare_bn`` is fitted on the observational
            ``"obs"`` data and stored as ``Exp.bayesNet``.

    Returns:
        dict: ``{asKey(intv): {"obs": tensor[, "rep": tensor]}}``.
    """
    dataset_dict = {}

    for dno, intv in enumerate(Exp.Data_intervs):
        _, _, _, real_labels_vars = get_training_variables(Exp, Exp.label_names, dno, intv)
        key = asKey(intv)
        dataset_dict[key] = {}

        # ---- discrete label columns (representation labels are generated below) ----
        each_dataset = [get_dataset(Exp, label, dno) for label in real_labels_vars
                        if label not in Exp.rep_labels]

        # ---- latent representation labels ----
        rep_dataset = []
        for rep in Exp.rep_labels:
            parent = Exp.Observed_DAG[rep][1]
            y_discrete = get_dataset(Exp, parent, dno)
            dim_list = [Exp.label_dim[parent]]
            label_onehots = get_multiple_labels_fill(Exp, y_discrete.view(-1, 1), dim_list, isImage_labels=False)

            # NOTE(review): images always come from the observational dataset while the
            # parent labels come from dataset `dno` — confirm this pairing is intended.
            image_values = image_data_dict[asKey({})]
            # The latents are used as fixed "real" data; building an autograd graph over
            # the whole image set would only be retained (and waste memory) during training.
            with torch.no_grad():
                image_latents = label_generators[rep](Exp, image_values, label_onehots, dim_list,
                                                      isOnehot=False, isLatent=True)
            rep_dataset.append(image_latents)

        if rep_dataset:
            dataset_dict[key]["rep"] = torch.cat(rep_dataset, 1).to(Exp.DEVICE)

        dataset_dict[key]["obs"] = torch.cat(each_dataset, 1).to(Exp.DEVICE)

        if dno == 0 and bayes_graph is not None:
            # NOTE(review): prepare_bn is not imported in this file — confirm its origin.
            Exp.bayesNet = prepare_bn(Exp, bayes_graph, dataset_dict[key]["obs"], load_scm=1)
            print(Exp.bayesNet.cpt('Y'))

    return dataset_dict

def get_intv_dataset(Exp):
    """Stack all interventional label datasets (indices 1 .. num_datasets-1) row-wise.

    For each dataset index, every label in ``Exp.label_names`` is read from
    ``Exp.file_roots[dno] + label + ".pkl"`` and becomes one float column. The columns
    are concatenated side by side, moved to ``Exp.DEVICE``, and the per-dataset
    tables are stacked along dim 0.

    Returns:
        torch.Tensor | None: the stacked table, or ``None`` when there is no
        interventional dataset (``Exp.num_datasets <= 1``).
    """
    def _read_column(root, label):
        # pickle executes code from the file: only use on trusted, locally produced data.
        with open(root + label + ".pkl", 'rb') as handle:
            column = torch.FloatTensor(pickle.load(handle))
        return column.view(len(column), 1)

    per_dataset = [
        torch.cat([_read_column(Exp.file_roots[dno], label) for label in Exp.label_names], 1).to(Exp.DEVICE)
        for dno in range(1, Exp.num_datasets)
    ]

    return torch.cat(per_dataset, 0) if per_dataset else None

def save_partial_training_journey(Exp, cur_mechs=None):
    """Append the current training stage to the model-journey log.

    The existing log, if any, is read from ``Exp.LOAD_MODEL_PATH + "/model_journeys.txt"``
    so that a run started from a previous model inherits its history. The extended log
    is written to ``Exp.SAVED_PATH + "/model_journeys.txt"``.

    Args:
        Exp: experiment object providing ``LOAD_MODEL_PATH`` and ``SAVED_PATH``.
        cur_mechs: labels trained in this stage. When omitted, the module-level
            ``cur_mechs`` (set by the ``__main__`` block) is used, as before.

    Raises:
        ValueError: if ``cur_mechs`` is omitted and no module-level ``cur_mechs`` exists.
    """
    if cur_mechs is None:
        # Backward compatibility: the original implementation read the script global.
        cur_mechs = globals().get("cur_mechs")
        if cur_mechs is None:
            raise ValueError("cur_mechs must be provided to save_partial_training_journey")

    file_name = Exp.LOAD_MODEL_PATH + "/model_journeys.txt"  # if not previous model, then it saves its own.
    journey_dict = {"journeys": []}
    if os.path.exists(file_name):
        with open(file_name) as f:
            journey_dict = json.load(f)

    journey_dict["journeys"].append({"cur_mech": "".join(cur_mechs), "path": Exp.SAVED_PATH})

    os.makedirs(Exp.SAVED_PATH, exist_ok=True)
    with open(Exp.SAVED_PATH + "/model_journeys.txt", 'w') as fp:
        fp.write(json.dumps(journey_dict))

def initialize_results(Exp, cur_hnodes):
    """Create empty TVD / KL histories for every evaluation query, resuming saved ones if requested.

    Keys are ``getdoKey(query["obs"], intv)`` for every intervention of every entry in
    ``Exp.interv_queries``, plus ``query["expr"]`` for every entry in ``Exp.cf_queries``.
    If any value of ``Exp.load_which_models`` is ``True``, histories saved under
    ``Exp.LOAD_MODEL_PATH + "/tvd/"`` and ``"/kl/"`` are loaded. Each file is loaded
    only if it exists.

    Args:
        Exp: experiment object.
        cur_hnodes: currently unused; kept for interface compatibility.

    Returns:
        tuple(dict, dict): ``(tvd_diff, kl_diff)`` mapping query key -> list of values.
    """
    tvd_diff = {}
    kl_diff = {}

    for query_list in Exp.interv_queries:
        for intv in query_list["intervs"]:
            # NOTE(review): getdoKey is not imported in this file — confirm its origin.
            query = getdoKey(query_list["obs"], intv)
            tvd_diff[query] = []
            kl_diff[query] = []

    for query in Exp.cf_queries:
        tvd_diff[query["expr"]] = []
        kl_diff[query["expr"]] = []

    if True in Exp.load_which_models.values():
        print("loading previous tvd diffs")
        for dist in tvd_diff:
            tvd_path = Exp.LOAD_MODEL_PATH + "/tvd/" + dist
            kl_path = Exp.LOAD_MODEL_PATH + "/kl/" + dist
            if os.path.exists(tvd_path):
                tvd_diff[dist] = torch.load(tvd_path).tolist()
            # Checked separately: a missing KL file used to crash even when the TVD file existed.
            if os.path.exists(kl_path):
                kl_diff[dist] = torch.load(kl_path).tolist()

    return tvd_diff, kl_diff



def train_CausalController(Exp, cur_mechs, label_generators, G_optimizers, label_discriminator, D_optimizer,
                           dataset_dict_batches, batchno):
    """Run one WGAN-GP step on mini-batch ``batchno`` for the mechanisms ``cur_mechs``.

    For every dataset in ``dataset_dict_batches`` (enumerated in insertion order as
    ``interv_no``), the critic ``label_discriminator[interv_no]`` is updated
    ``Exp.CRITIC_ITERATIONS`` times. The generator loss ``-mean(D(fake))`` is
    accumulated over datasets, and the generators of ``cur_mechs`` then take one
    optimizer step on the summed loss.

    Args:
        Exp: experiment object.
        cur_mechs: list of mechanism labels trained together.
        label_generators / G_optimizers: per-label generators and their optimizers.
        label_discriminator / D_optimizer: critics and optimizers indexed by dataset
            position (``[-2]``/``[-1]`` are also used when ``Exp.test_marginals``).
        dataset_dict_batches: ``intervention key -> {"obs": [batches], "rep"?: [...], "img"?: [...]}``;
            each key must be convertible with ``dict()`` (presumably the output of ``asKey``).
        batchno: index of the mini-batch to use in every list.

    Returns:
        tuple: ``(G_loss.data, D_loss.data, image_loss, label_loss)``. ``D_loss`` is the
        mean critic loss over all processed datasets and critic iterations (``nan`` if
        no dataset was processed).
    """
    G_loss = torch.zeros(1).to(Exp.DEVICE)
    # Collected across ALL datasets so the reported losses cover every intervention,
    # not just the last one processed.
    D_losses = []
    image_loss = []
    label_loss = []
    trained_any = False

    for interv_no, (intv_key, dataset_batches) in enumerate(dataset_dict_batches.items()):
        intv_key = dict(intv_key)

        data_input = dataset_batches["obs"][batchno]

        _, _, _, graph_label_vars = get_training_variables(Exp, Exp.label_names, interv_no, intv_key)
        all_compare_Var, compare_Var, intervened_Var, real_labels_vars = get_training_variables(Exp, cur_mechs, interv_no, intv_key)

        # TODO: datasets lacking columns for every needed label are skipped for now.
        if len(real_labels_vars) > data_input.shape[1]:
            continue
        trained_any = True

        has_image = set(cur_mechs) & set(Exp.image_labels) != set()
        has_rep = set(all_compare_Var) & set(Exp.rep_labels) != set()

        mini_batch = data_input.size()[0]
        # data_input columns follow graph_label_vars; select the ones this mechanism needs.
        indices = [graph_label_vars.index(lb) for lb in real_labels_vars]
        current_real_label = data_input[:, indices].type(torch.LongTensor).view(-1, len(indices)).to(Exp.DEVICE)
        dims_list = [Exp.label_dim[lb] for lb in real_labels_vars]

        obs_images = None
        if has_image:
            real_labels_fill = get_multiple_labels_fill(Exp, current_real_label, dims_list, isImage_labels=True, more_dimsize=Exp.IMAGE_SIZE)
            obs_images = dataset_batches["img"][batchno]
        elif has_rep:
            real_labels_fill = get_multiple_labels_fill(Exp, current_real_label, dims_list, isImage_labels=False)
            real_labels_fill = torch.cat([real_labels_fill, dataset_batches["rep"][batchno]], 1)
        else:
            real_labels_fill = get_multiple_labels_fill(Exp, current_real_label, dims_list, isImage_labels=False)

        intv_tensor_dict = {}
        isClassifier = False
        for intv_lb in intervened_Var:  # no intervention -> no looping
            if intv_lb in Exp.image_labels:
                intv_parent_fill = obs_images
                isClassifier = True
            else:
                # Read the intervened value from the columns already aligned with
                # real_labels_vars; the raw data_input columns follow graph_label_vars.
                ind = real_labels_vars.index(intv_lb)
                parent_intv_label = current_real_label[:, ind].reshape(-1, 1)
                intv_parent_fill = get_multiple_labels_fill(Exp, parent_intv_label, [Exp.label_dim[intv_lb]], isImage_labels=False)
            intv_tensor_dict[intv_lb] = intv_parent_fill

        generated_image = None
        if has_image:
            generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_tensor_dict,
                                                         intervened_Var + all_compare_Var, mini_batch, hard=True)
            generated_image = generated_labels_dict[Exp.image_labels[0]]
            del generated_labels_dict[Exp.image_labels[0]]
            y_dims = sum([Exp.label_dim[lb] for lb in real_labels_vars])
            ret = list(generated_labels_dict.values())
            ret2d = torch.cat(ret, 1).view(-1, y_dims)  # for critic
            generated_labels_fill = fill2d_to_fill4d(Exp, ret2d, more_dimsize=Exp.IMAGE_SIZE)
        elif has_rep:
            generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_tensor_dict, real_labels_vars + Exp.rep_labels, mini_batch)
            y_dims = sum([Exp.label_dim[lb] for lb in real_labels_vars + Exp.rep_labels])
            ret = list(generated_labels_dict.values())
            generated_labels_fill = torch.cat(ret, 1).view(-1, y_dims)
        else:
            generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_tensor_dict, real_labels_vars, mini_batch)
            y_dims = sum([Exp.label_dim[lb] for lb in real_labels_vars])
            ret = list(generated_labels_dict.values())
            generated_labels_fill = torch.cat(ret, 1).view(-1, y_dims)

        for crit_ in range(Exp.CRITIC_ITERATIONS):

            if Exp.test_marginals:
                # Diagnostic-only critics; their losses are recorded but never backpropagated.
                D_real_image = label_discriminator[-2](obs_images, real_labels_fill[:,0:10,:,:]).squeeze()
                D_fake_image = label_discriminator[-2](generated_image, generated_labels_fill[:,0:10,:,:]).squeeze()
                gp_image = labels_image_gradient_penalty(label_discriminator[-2], obs_images, real_labels_fill[:,0:10,:,:],
                                                       generated_image, generated_labels_fill[:,0:10,:,:], isClassifier,
                                                       device=Exp.DEVICE)
                D_loss_image = (-  (torch.mean(D_real_image) - torch.mean(D_fake_image)) + Exp.LAMBDA_GP * gp_image)
                image_loss.append((D_loss_image).data)

                real1 = data_input[:,1].type(torch.LongTensor).view(-1, 1).to(Exp.DEVICE)
                real_color = get_multiple_labels_fill(Exp, real1, [4], isImage_labels=False)
                fake_color  = generated_labels_dict["C"]
                D_real_color = label_discriminator[-1](real_color).squeeze()
                D_fake_color = label_discriminator[-1](fake_color).squeeze()
                gp_labels = calc_gradient_penalty(label_discriminator[-1], real_color, fake_color, device=Exp.DEVICE)
                D_loss_labels_temp = (-  (torch.mean(D_real_color) - torch.mean(D_fake_color)) )
                D_loss_labels = D_loss_labels_temp + Exp.LAMBDA_GP * gp_labels
                label_loss.append((D_loss_labels).data)

            if has_image:
                D_real_decision_obs = label_discriminator[interv_no](obs_images, real_labels_fill).squeeze()
                D_fake_decision_obs = label_discriminator[interv_no](generated_image, generated_labels_fill).squeeze()
                gp_obs = labels_image_gradient_penalty(label_discriminator[interv_no], obs_images, real_labels_fill, generated_image, generated_labels_fill, isClassifier,
                                           device=Exp.DEVICE)
            else:
                D_real_decision_obs = label_discriminator[interv_no](real_labels_fill).squeeze()
                D_fake_decision_obs = label_discriminator[interv_no](generated_labels_fill).squeeze()
                gp_obs = calc_gradient_penalty(label_discriminator[interv_no], real_labels_fill, generated_labels_fill, device=Exp.DEVICE)

            # WGAN-GP critic loss: -(E[D(real)] - E[D(fake)]) + lambda * gradient penalty.
            D_loss_obs = (-  (torch.mean(D_real_decision_obs) - torch.mean(D_fake_decision_obs)) + Exp.LAMBDA_GP * gp_obs)
            D_losses.append((D_loss_obs).data)  # just a loss list

            label_discriminator[interv_no].zero_grad()
            # retain_graph: the generator graph behind generated_labels_fill is reused by
            # the next critic iteration and by the generator loss below.
            D_loss_obs.backward(retain_graph=True)
            D_optimizer[interv_no].step()

        # Accumulate the generator loss for all interventions, using the updated critic.
        if has_image:
            D_fake_decision_obs = label_discriminator[interv_no](generated_image, generated_labels_fill).squeeze()
        else:
            D_fake_decision_obs = label_discriminator[interv_no](generated_labels_fill).squeeze()

        G_loss += -torch.mean(D_fake_decision_obs)

    # Back propagation (skipped when every dataset was skipped: G_loss has no graph then).
    if trained_any:
        for mech in cur_mechs:
            label_generators[mech].zero_grad()

        G_loss.backward()

        for mech in cur_mechs:
            G_optimizers[mech].step()

    D_loss = torch.mean(torch.FloatTensor(D_losses))  # just mean of losses

    if Exp.test_marginals:
        image_loss = torch.mean(torch.FloatTensor(image_loss))
        label_loss = torch.mean(torch.FloatTensor(label_loss))
    else:
        image_loss, label_loss = torch.zeros(1).to(Exp.DEVICE), torch.zeros(1).to(Exp.DEVICE)

    return G_loss.data, D_loss.data, image_loss, label_loss




def _batch_tensor(tensor, batch_size, ensure_2d):
    """Split ``tensor`` into consecutive, unshuffled mini-batches.

    Singleton dimensions other than the batch dimension are dropped. With
    ``ensure_2d``, a batch that ends up 1-D is reshaped to ``(batch, 1)``. Unlike a
    bare ``torch.squeeze``, the batch dimension is never removed, so a trailing batch
    of size 1 keeps its row layout.
    """
    batches = []
    for batch in torch.utils.data.DataLoader(dataset=tensor, batch_size=batch_size, shuffle=False):
        batch = batch.reshape(batch.size(0), *[s for s in batch.shape[1:] if s != 1])
        if ensure_2d and batch.dim() == 1:
            batch = batch.view(-1, 1)
        batches.append(batch)
    return batches


def labelMain(Exp, cur_hnodes, label_generators, G_optimizers, discriminators, D_optimizers, dataset_dict,
              tvd_diff, kl_diff):
    """Train all hnodes for one epoch (``Exp.curr_epoochs``), then evaluate and checkpoint.

    Args:
        Exp: experiment object.
        cur_hnodes: mapping ``hnode -> list of mechanism labels`` trained together.
        label_generators / G_optimizers: per-label generators and optimizers.
        discriminators / D_optimizers: mapping ``hnode -> critics / critic optimizers``.
        dataset_dict: ``intervention key -> {"obs": tensor[, "rep": tensor][, "img": tensor]}``.
        tvd_diff / kl_diff: evaluation histories, passed to ``imageMediatorEvaluation``.

    Returns:
        int: the placeholder value 100 (the TVD-based return is disabled).
    """
    dataset_dict_batches = {}
    batch_counts = []
    for key, each_dataset in dataset_dict.items():
        batches = {"obs": _batch_tensor(each_dataset["obs"], Exp.batch_size, ensure_2d=True)}
        if len(Exp.rep_labels):
            batches["rep"] = _batch_tensor(each_dataset["rep"], Exp.batch_size, ensure_2d=True)
        # Only some datasets may carry images (e.g. just the observational one).
        if len(Exp.image_labels) and "img" in each_dataset:
            batches["img"] = _batch_tensor(each_dataset["img"], Exp.batch_size, ensure_2d=False)
        dataset_dict_batches[key] = batches
        batch_counts.extend(len(lst) for lst in batches.values())

    # Use the shortest batch list so batchno never runs past any dataset.
    num_batches = min(batch_counts, default=0)

    # Every mechanism trained this epoch, across all hnodes (used for checkpointing).
    all_mechs = [mech for mechs in cur_hnodes.values() for mech in mechs]

    for iteration in range(num_batches):
        batchno = iteration

        for hn, cur_mechs in cur_hnodes.items():
            g_loss, d_loss, image_loss, label_loss = train_CausalController(Exp, cur_mechs, label_generators, G_optimizers, discriminators[hn],
                                                                            D_optimizers[hn], dataset_dict_batches, batchno)

            print('Epoch [%d/%d], Step [%d/%d],' % (
                Exp.curr_epoochs + 1, Exp.num_epochs, iteration + 1, num_batches),
              'mechanism: ', cur_mechs, ' D_loss: %.4f, G_loss: %.4f' % (d_loss.data, g_loss.data))

        # Annealing
        tot_iter = Exp.curr_epoochs * num_batches + iteration
        if tot_iter % 100 == 0:
            Exp.anneal_temperature(tot_iter)

        Exp.D_avg_losses.append(torch.mean(d_loss))
        Exp.G_avg_losses.append(torch.mean(g_loss))

    # Evaluate after every epoch. Swap in another *Evaluation(...) call for other benchmarks.
    print("Turn on caffeinate or these results are gone!")
    tvd_diff, kl_diff = imageMediatorEvaluation(Exp, cur_hnodes, label_generators, dataset_dict, tvd_diff, kl_diff)

    if (Exp.curr_epoochs + 1) % 5 == 0:
        # Save the mechanisms of every hnode. The old code used the leaked loop
        # variable, i.e. only the last hnode's mechanisms.
        var_list = "".join(all_mechs)
        # NOTE(review): save_checkpoint is not imported in this file — confirm its origin.
        save_checkpoint(Exp, Exp.SAVED_PATH, all_mechs, label_generators, G_optimizers, {var_list: discriminators}, {var_list: D_optimizers})
        print(Exp.curr_epoochs, ":model saved at ", Exp.SAVED_PATH)

    return 100





if __name__ == "__main__":
    # Script entry point: configure the experiment, build models and datasets for one
    # c-component (selected by comp_no), then call labelMain once per epoch.
    # NOTE(review): as written this block cannot run end-to-end — see the notes below.

    # temp, dlayer, gp
    Exp = Experiment("Exp1", set_mnist_nonId_newgraph,
                     dist_thresh=0.15,
                     causal_hierarchy=2,
                     Temperature=1,
                     temp_min=0.1,
                     G_hid_dims=[256, 256],
                     D_hid_dims=[256, 256, 256],
                     IMAGE_FILTERS=[128, 64, 32],
                     CRITIC_ITERATIONS=5,
                     LAMBDA_GP=10,
                     learning_rate=2 * 1e-4,
                     Synthetic_Sample_Size=40000,
                     intv_Sample_Size=40000,
                     batch_size=200,
                     features=["feature"],
                     noise_states=100,
                     latent_state=16,
                     Data_intervs=[{}, {"X1":0}, {"X1":1}],
                     num_epochs=300,
                     new_experiment=True
                     )


    print(Exp.Data_intervs)
    Exp.intv_batch_size = Exp.batch_size
    # True scm

    os.makedirs(Exp.SAVED_PATH, exist_ok=True)
    # NOTE(review): dag_name is never used below.
    dag_name = Exp.Complete_DAG_desc + ".txt"

    # Load previous model results also

    # Placeholder path: replace with a real run directory before enabling any load flag.
    Exp.LOAD_MODEL_PATH = "SAVED_EXPERIMENTS/mnist_nonId_newgraph/Exp1/Mon_DD_YYYY-HH_MM"  #Ex: Sep_20_2022-13_52
    # Per-mechanism flags; initialize_results resumes saved TVD/KL histories when any is True.
    Exp.load_which_models = {"X1": False, "X2": False, "W": False, "Ydigit1": False, "Ydigit2": False, "Ycolor": False,
                                                      "Ythick": False}
    # Exp.load_which_models = {"X1": True, "X2": True, "W": True, "Ydigit1": True, "Ydigit2": True, "Ycolor": True,
    #                          "Ythick": True}


    # Each component: the value assigned to Exp.num_datasets and the mechanisms trained together.
    c_components= [{"num_dataset":3, "cur_mechs" : ["X1", "X2", "W", "Ycolor"]},
                   {"num_dataset":1, "cur_mechs" : ["Ydigit1", "Ydigit2", "Ythick"]}]

    comp_no=1
    # for train_no, each_com in enumerate(c_components):
    ##############****************##############
    each_com = c_components[comp_no]
    cur_mechs = each_com["cur_mechs"]
    Exp.num_datasets = each_com["num_dataset"]


    # NOTE(review): get_generators is not imported in this file — confirm its origin.
    label_generators, optimizersMech = get_generators(Exp, Exp.load_which_models)



    discriminatorsMech, doptimizersMech = get_discriminators(Exp, cur_mechs, Exp.load_which_models)  #



    # NOTE(review): both loaders iterate cur_hnodes.items() / expect an hnode mapping, but a
    # plain list is passed here, so load_image_dataset raises AttributeError. A
    # {hnode: mechs} dict is presumably intended — confirm.
    image_data_dict = load_image_dataset(Exp, each_com['cur_mechs'])
    dataset_dict = load_label_dataset(Exp, image_data_dict, label_generators, each_com['cur_mechs'])
    # Attach the observational images to the observational dataset entry.
    dataset_dict[asKey({})]["img"]=image_data_dict[asKey({})]


    # load datasets without images
    # NOTE(review): this rebuild overwrites the structured dataset_dict above with plain
    # tensors (no "obs"/"rep"/"img" keys), which labelMain cannot index — looks like
    # leftover code; confirm and remove.
    dataset_dict = {}
    for dno in range(Exp.num_datasets):
        each_dataset = []
        for label in Exp.label_names:
            if label not in Exp.image_labels:
                each_dataset.append(get_dataset(Exp, label, dno))
        dataset_dict[asKey(Exp.Data_intervs[dno])] = torch.cat(each_dataset, 1).to(Exp.DEVICE)



    # NOTE(review): initialize_results(Exp, cur_hnodes) returns (tvd_diff, kl_diff); this
    # call omits cur_hnodes (TypeError) and discards the result, so tvd_diff / kl_diff
    # used below are undefined.
    initialize_results(Exp)


    mech_tvd = 0
    print("Starting training new mechanism")


    for epoch in range(Exp.num_epochs):
        Exp.curr_epoochs = epoch
        # NOTE(review): labelMain takes 9 parameters (no image_data_dict) and expects a
        # {hnode: mechs} mapping for cur_hnodes; this call passes 10 arguments -> TypeError.
        mech_tvd = labelMain(Exp, cur_mechs, label_generators, optimizersMech, discriminatorsMech, doptimizersMech, dataset_dict, image_data_dict, tvd_diff, kl_diff)
